{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Densenet.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "JIivNnSwyODH",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 461
        },
        "outputId": "514e38dd-86d7-49dc-c238-aeb814b574f8"
      },
      "source": [
        "import torch\n",
        "import torchvision.models as models\n",
        "import torch.nn as nn\n",
        "from torch.autograd import Variable\n",
        "import torch.nn.functional as F\n",
        "# Probe DenseNet-161's feature extractor: feed a dummy 640x480 RGB batch through\n",
        "# each top-level child of `features` and keep every intermediate activation.\n",
        "model_ft = models.densenet161(pretrained=True)\n",
        "model_ft.eval()  # inference mode: normalization layers do not update running statistics\n",
        "x = torch.randn([1, 3, 640, 480])\n",
        "\n",
        "features = [x]\n",
        "# Only shapes are inspected, so skip building the autograd graph.\n",
        "with torch.no_grad():\n",
        "    # Labels start at -2, so the third child of `features` is printed as 0.\n",
        "    for label, (_, layer) in enumerate(model_ft.features.named_children(), start=-2):\n",
        "        print(label)\n",
        "        features.append(layer(features[-1]))\n",
        "\n",
        "# features[0] is the input itself; features[k] is the output of the k-th child.\n",
        "# Distinct loop names keep the input tensor `x` from being overwritten.\n",
        "for idx, feat in enumerate(features):\n",
        "    print(idx, feat.shape)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "-2\n",
            "-1\n",
            "0\n",
            "1\n",
            "2\n",
            "3\n",
            "4\n",
            "5\n",
            "6\n",
            "7\n",
            "8\n",
            "9\n",
            "0 torch.Size([1, 3, 640, 480])\n",
            "1 torch.Size([1, 96, 320, 240])\n",
            "2 torch.Size([1, 96, 320, 240])\n",
            "3 torch.Size([1, 96, 320, 240])\n",
            "4 torch.Size([1, 96, 160, 120])\n",
            "5 torch.Size([1, 384, 160, 120])\n",
            "6 torch.Size([1, 192, 80, 60])\n",
            "7 torch.Size([1, 768, 80, 60])\n",
            "8 torch.Size([1, 384, 40, 30])\n",
            "9 torch.Size([1, 2112, 40, 30])\n",
            "10 torch.Size([1, 1056, 20, 15])\n",
            "11 torch.Size([1, 2208, 20, 15])\n",
            "12 torch.Size([1, 2208, 20, 15])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eoGzvGeB2mgp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class UpSample(nn.Sequential):\n",
        "    def __init__(self, skip_input, output_features):\n",
        "        super(UpSample, self).__init__()        \n",
        "        self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1)\n",
        "        self.leakyreluA = nn.LeakyReLU(0.2)\n",
        "        self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1)\n",
        "        self.leakyreluB = nn.LeakyReLU(0.2)\n",
        "\n",
        "    def forward(self, x, concat_with):\n",
        "        up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)\n",
        "        return self.leakyreluB( self.convB( self.leakyreluA(self.convA( torch.cat([up_x, concat_with], dim=1) ) ) )  )\n",
        "\n",
        "class Decoder(nn.Module):\n",
        "    \"\"\"Fuses encoder feature maps through four UpSample stages into a 1-channel map.\n",
        "\n",
        "    Args:\n",
        "        num_features: Channel count of the deepest feature map consumed (features[11]).\n",
        "        decoder_width: Fraction of `num_features` kept by the 1x1 bottleneck.\n",
        "\n",
        "    The skip channel counts (384, 192, 96, 96) are hard-coded for the tensors that\n",
        "    `forward` picks at indices 8, 6, 4 and 3 of the encoder output.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_features=2208, decoder_width = 0.5):\n",
        "        super(Decoder, self).__init__()\n",
        "        features = int(num_features * decoder_width)\n",
        "\n",
        "        # 1x1 bottleneck. padding must be 0 here: with kernel_size=1, padding=1 would\n",
        "        # grow the map by 2 pixels per spatial dimension and add a ring of bias-only\n",
        "        # values that the first bilinear resize then blends into real features.\n",
        "        self.conv2 = nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=0)\n",
        "\n",
        "        # Each stage halves the channel count; the added constant is the skip tensor's channels.\n",
        "        self.up1 = UpSample(skip_input=features//1 + 384, output_features=features//2)\n",
        "        self.up2 = UpSample(skip_input=features//2 + 192, output_features=features//4)\n",
        "        self.up3 = UpSample(skip_input=features//4 +  96, output_features=features//8)\n",
        "        self.up4 = UpSample(skip_input=features//8 +  96, output_features=features//16)\n",
        "\n",
        "        # Final 3x3 conv maps to a single output channel at unchanged resolution.\n",
        "        self.conv3 = nn.Conv2d(features//16, 1, kernel_size=3, stride=1, padding=1)\n",
        "\n",
        "    def forward(self, features):\n",
        "        \"\"\"Decode a list of encoder feature maps into a 1-channel map.\n",
        "\n",
        "        Args:\n",
        "            features: Sequence of tensors indexable at 3, 4, 6, 8 and 11.\n",
        "\n",
        "        Returns:\n",
        "            Tensor with one channel and the spatial size of `features[3]`.\n",
        "        \"\"\"\n",
        "        # NOTE(review): features[12], the last encoder output, is not used here; confirm\n",
        "        # that index 11 rather than 12 is the intended decoder input.\n",
        "        x_block0, x_block1, x_block2, x_block3, x_block4 = features[3], features[4], features[6], features[8], features[11]\n",
        "        x_d0 = self.conv2(x_block4)\n",
        "        # Every stage resizes to its skip tensor's size, moving from deepest to shallowest.\n",
        "        x_d1 = self.up1(x_d0, x_block3)\n",
        "        x_d2 = self.up2(x_d1, x_block2)\n",
        "        x_d3 = self.up3(x_d2, x_block1)\n",
        "        x_d4 = self.up4(x_d3, x_block0)\n",
        "        return self.conv3(x_d4)\n",
        "\n",
        "class Encoder(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Encoder, self).__init__()       \n",
        "        import torchvision.models as models\n",
        "        self.original_model = models.densenet161( pretrained=True )\n",
        "\n",
        "    def forward(self, x):\n",
        "        features = [x]\n",
        "        for k, v in self.original_model.features._modules.items(): features.append( v(features[-1]) )\n",
        "        return features\n",
        "\n",
        "class Model(nn.Module):\n",
        "    \"\"\"Full network: the Encoder's list of feature maps is fed to the Decoder.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.encoder = Encoder()\n",
        "        self.decoder = Decoder()\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Encode `x` into a list of feature maps, then decode that list.\"\"\"\n",
        "        feature_maps = self.encoder(x)\n",
        "        return self.decoder(feature_maps)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}